{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d9d1a28b-c2a1-4246-b1c2-c58d6938f798",
   "metadata": {},
   "source": [
    "# Customer Support\n",
    "\n",
    "Here, we show an example of building a customer support chatbot.\n",
    "\n",
    "This customer support chatbot interacts with SQL database to answer questions.\n",
    "We will use a mock SQL database to get started: the [Chinook](https://www.sqlitetutorial.net/sqlite-sample-database/) database.\n",
    "This database is about sales from a music store: what songs and album exists, customer orders, things like that.\n",
    "\n",
    "This chatbot has two different states: \n",
    "1. Music: the user can inquire about different songs and albums present in the store\n",
    "2. Account: the user can ask questions about their account\n",
    "\n",
    "Under the hood, this is handled by two separate agents. \n",
    "Each has a specific prompt and tools related to their objective. \n",
    "There is also a generic agent who is responsible for routing between these two agents as needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "58c8fd46-843c-4bd6-a5f7-2a50676d1b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install -U scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "35abc013-2613-4a49-a806-939dcf13ccf3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: langgraph in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (0.0.24)\n",
      "Collecting langgraph\n",
      "  Downloading langgraph-0.0.26-py3-none-any.whl.metadata (34 kB)\n",
      "Collecting langchain-core<0.2.0,>=0.1.25 (from langgraph)\n",
      "  Downloading langchain_core-0.1.26-py3-none-any.whl.metadata (6.0 kB)\n",
      "Requirement already satisfied: PyYAML>=5.3 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.25->langgraph) (6.0.1)\n",
      "Requirement already satisfied: anyio<5,>=3 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.25->langgraph) (3.7.1)\n",
      "Requirement already satisfied: jsonpatch<2.0,>=1.33 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.25->langgraph) (1.33)\n",
      "Collecting langsmith<0.2.0,>=0.1.0 (from langchain-core<0.2.0,>=0.1.25->langgraph)\n",
      "  Downloading langsmith-0.1.8-py3-none-any.whl.metadata (13 kB)\n",
      "Requirement already satisfied: packaging<24.0,>=23.2 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.25->langgraph) (23.2)\n",
      "Requirement already satisfied: pydantic<3,>=1 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.25->langgraph) (2.4.2)\n",
      "Requirement already satisfied: requests<3,>=2 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.25->langgraph) (2.31.0)\n",
      "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.25->langgraph) (8.2.3)\n",
      "Requirement already satisfied: idna>=2.8 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from anyio<5,>=3->langchain-core<0.2.0,>=0.1.25->langgraph) (3.4)\n",
      "Requirement already satisfied: sniffio>=1.1 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from anyio<5,>=3->langchain-core<0.2.0,>=0.1.25->langgraph) (1.3.0)\n",
      "Requirement already satisfied: jsonpointer>=1.9 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from jsonpatch<2.0,>=1.33->langchain-core<0.2.0,>=0.1.25->langgraph) (2.4)\n",
      "Collecting orjson<4.0.0,>=3.9.14 (from langsmith<0.2.0,>=0.1.0->langchain-core<0.2.0,>=0.1.25->langgraph)\n",
      "  Downloading orjson-3.9.15-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl.metadata (49 kB)\n",
      "\u001b[2K     \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.5/49.5 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: annotated-types>=0.4.0 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from pydantic<3,>=1->langchain-core<0.2.0,>=0.1.25->langgraph) (0.6.0)\n",
      "Requirement already satisfied: pydantic-core==2.10.1 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from pydantic<3,>=1->langchain-core<0.2.0,>=0.1.25->langgraph) (2.10.1)\n",
      "Requirement already satisfied: typing-extensions>=4.6.1 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from pydantic<3,>=1->langchain-core<0.2.0,>=0.1.25->langgraph) (4.8.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from requests<3,>=2->langchain-core<0.2.0,>=0.1.25->langgraph) (3.3.0)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from requests<3,>=2->langchain-core<0.2.0,>=0.1.25->langgraph) (2.0.7)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages (from requests<3,>=2->langchain-core<0.2.0,>=0.1.25->langgraph) (2023.7.22)\n",
      "Downloading langgraph-0.0.26-py3-none-any.whl (44 kB)\n",
      "\u001b[2K   \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.4/44.4 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading langchain_core-0.1.26-py3-none-any.whl (246 kB)\n",
      "\u001b[2K   \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m246.4/246.4 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hDownloading langsmith-0.1.8-py3-none-any.whl (62 kB)\n",
      "\u001b[2K   \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.2/62.2 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading orjson-3.9.15-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl (248 kB)\n",
      "\u001b[2K   \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m248.6/248.6 kB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: orjson, langsmith, langchain-core, langgraph\n",
      "  Attempting uninstall: langsmith\n",
      "    Found existing installation: langsmith 0.0.87\n",
      "    Uninstalling langsmith-0.0.87:\n",
      "      Successfully uninstalled langsmith-0.0.87\n",
      "  Attempting uninstall: langchain-core\n",
      "    Found existing installation: langchain-core 0.1.22\n",
      "    Uninstalling langchain-core-0.1.22:\n",
      "      Successfully uninstalled langchain-core-0.1.22\n",
      "  Attempting uninstall: langgraph\n",
      "    Found existing installation: langgraph 0.0.24\n",
      "    Uninstalling langgraph-0.0.24:\n",
      "      Successfully uninstalled langgraph-0.0.24\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "langchain-community 0.0.19 requires langsmith<0.1,>=0.0.83, but you have langsmith 0.1.8 which is incompatible.\n",
      "langchain 0.1.6 requires langsmith<0.1,>=0.0.83, but you have langsmith 0.1.8 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0mSuccessfully installed langchain-core-0.1.26 langgraph-0.0.26 langsmith-0.1.8 orjson-3.9.15\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.11 -m pip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install -U langgraph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9431e7f1-07fa-49d9-ac45-29613703dcc1",
   "metadata": {},
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "61f7ef9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fd816e9f-fc94-476d-84eb-0e2a3c370c5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/harrisonchase/.pyenv/versions/3.11.1/envs/permchain/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:117: LangChainDeprecationWarning: The function `get_table_names` was deprecated in LangChain 0.0.1 and will be removed in 0.2.0. Use get_usable_table_name instead.\n",
      "  warn_deprecated(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['Album',\n",
       " 'Artist',\n",
       " 'Customer',\n",
       " 'Employee',\n",
       " 'Genre',\n",
       " 'Invoice',\n",
       " 'InvoiceLine',\n",
       " 'MediaType',\n",
       " 'Playlist',\n",
       " 'PlaylistTrack',\n",
       " 'Track']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "db.get_table_names()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cf668e4-8cb4-4de1-bc5e-c90284bf74bc",
   "metadata": {},
   "source": [
    "## Load an LLM\n",
    "\n",
    "We will load a language model to use.\n",
    "For this demo we will use OpenAI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d9ea4e80-30e6-4d46-b480-35f0be2fb055",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# We will set streaming=True so that we can stream tokens\n",
    "# See the streaming section for more information on this.\n",
    "model = ChatOpenAI(temperature=0, streaming=True, model=\"gpt-4-turbo-preview\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73907422-7e05-431e-b06d-256c9ec1f6f6",
   "metadata": {},
   "source": [
    "## Load Other Modules\n",
    "\n",
    "Load other modules we will use.\n",
    "\n",
    "All of the tools our agents will use will be custom tools. As such, we will use the `@tool` decorator to create custom tools.\n",
    "\n",
    "We will pass in messages to the agent, so we load `HumanMessage` and `SystemMessage`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ea958e9f-ab1f-49b5-bd85-16332055297c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import tool\n",
    "from langchain_core.messages import HumanMessage, SystemMessage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35271d4d-2a1c-41be-9359-a3a7c3fed3d9",
   "metadata": {},
   "source": [
    "## Define the Customer Agent\n",
    "\n",
    "This agent is responsible for looking up customer information.\n",
    "It will have a specific prompt as well a specific tool to look up information about that customer (after asking for their user id)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "975b039a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This tool is given to the agent to look up information about a customer\n",
    "@tool\n",
    "def get_customer_info(customer_id: int):\n",
    "    \"\"\"Look up customer info given their ID. ALWAYS make sure you have the customer ID before invoking this.\"\"\"\n",
    "    return db.run(f\"SELECT * FROM Customer WHERE CustomerID = {customer_id};\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1d5fa446",
   "metadata": {},
   "outputs": [],
   "source": [
    "customer_prompt = \"\"\"Your job is to help a user update their profile.\n",
    "\n",
    "You only have certain tools you can use. These tools require specific input. If you don't know the required input, then ask the user for it.\n",
    "\n",
    "If you are unable to help the user, you can \"\"\"\n",
    "\n",
    "def get_customer_messages(messages):\n",
    "    return [SystemMessage(content=customer_prompt)] + messages\n",
    "\n",
    "customer_chain = get_customer_messages | model.bind_tools([get_customer_info])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "904a9485-3857-458e-8b9d-33bc33842bc9",
   "metadata": {},
   "source": [
    "## Define the Music Agent\n",
    "\n",
    "This agent is responsible for figuring out information about music. To do that, we will create a prompt and various tools for looking up information about music\n",
    "\n",
    "First, we will create indexes for looking up artists and track names.\n",
    "This will allow us to look up artists and tracks without having to spell their names exactly right."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a8604a3b-b484-4b2b-a914-4236cb98c524",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.vectorstores import SKLearnVectorStore\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "\n",
    "artists = db._execute(\"select * from Artist\")\n",
    "songs = db._execute(\"select * from Track\")\n",
    "artist_retriever = SKLearnVectorStore.from_texts(\n",
    "    [a['Name'] for a in artists],\n",
    "    OpenAIEmbeddings(), \n",
    "    metadatas=artists\n",
    ").as_retriever()\n",
    "song_retriever = SKLearnVectorStore.from_texts(\n",
    "    [a['Name'] for a in songs],\n",
    "    OpenAIEmbeddings(), \n",
    "    metadatas=songs\n",
    ").as_retriever()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac7eb264-c572-4925-ad55-a1d52a18b1c0",
   "metadata": {},
   "source": [
    "First, let's create a tool for getting albums by artist."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0a2a2b74",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tool\n",
    "def get_albums_by_artist(artist):\n",
    "    \"\"\"Get albums by an artist (or similar artists).\"\"\"\n",
    "    docs = artist_retriever.get_relevant_documents(artist)\n",
    "    artist_ids = \", \".join([str(d.metadata['ArtistId']) for d in docs])\n",
    "    return db.run(f\"SELECT Title, Name FROM Album LEFT JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Album.ArtistId in ({artist_ids});\", include_columns=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45e85066-f2fc-490e-992d-cd66c9cd6486",
   "metadata": {},
   "source": [
    "Next, lets create a tool for getting tracks by an artist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "da533f50",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tool\n",
    "def get_tracks_by_artist(artist):\n",
    "    \"\"\"Get songs by an artist (or similar artists).\"\"\"\n",
    "    docs = artist_retriever.get_relevant_documents(artist)\n",
    "    artist_ids = \", \".join([str(d.metadata['ArtistId']) for d in docs])\n",
    "    return db.run(f\"SELECT Track.Name as SongName, Artist.Name as ArtistName FROM Album LEFT JOIN Artist ON Album.ArtistId = Artist.ArtistId LEFT JOIN Track ON Track.AlbumId = Album.AlbumId WHERE Album.ArtistId in ({artist_ids});\", include_columns=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb0e50ab-b059-427c-924b-f8072d8db23c",
   "metadata": {},
   "source": [
    "Finally, let's create a tool for looking up songs by their name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b3c07010",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tool\n",
    "def check_for_songs(song_title):\n",
    "    \"\"\"Check if a song exists by its name.\"\"\"\n",
    "    return song_retriever.get_relevant_documents(song_title)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88388ff8-38b5-4e4e-a24d-de8c3670bd2b",
   "metadata": {},
   "source": [
    "Create the chain to call the relevant tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "72a14d5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "song_system_message = \"\"\"Your job is to help a customer find any songs they are looking for. \n",
    "\n",
    "You only have certain tools you can use. If a customer asks you to look something up that you don't know how, politely tell them what you can help with.\n",
    "\n",
    "When looking up artists and songs, sometimes the artist/song will not be found. In that case, the tools will return information \\\n",
    "on simliar songs and artists. This is intentional, it is not the tool messing up.\"\"\"\n",
    "def get_song_messages(messages):\n",
    "    return [SystemMessage(content=song_system_message)] + messages\n",
    "\n",
    "song_recc_chain = get_song_messages | model.bind_tools([get_albums_by_artist, get_tracks_by_artist, check_for_songs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cff15eb0-62c7-451d-a5f9-4576b24c879e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_LxRLQdYKVzGHMGgwGTIIMvBO', 'function': {'arguments': '{\"artist\":\"amy winehouse\"}', 'name': 'get_tracks_by_artist'}, 'type': 'function'}]})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "msgs = [HumanMessage(content=\"hi! can you help me find songs by amy whinehouse?\")]\n",
    "song_recc_chain.invoke(msgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a42c293-0816-4f3c-b4a3-5b9f3a0665d1",
   "metadata": {},
   "source": [
    "## Define the Generic Agent\n",
    "\n",
    "We now define a generic agent that is responsible for handling initial inquiries and routing to the right sub agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "73e74268",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import SystemMessage, HumanMessage, AIMessage\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "\n",
    "class Router(BaseModel):\n",
    "    \"\"\"Call this if you are able to route the user to the appropriate representative.\"\"\"\n",
    "    choice: str = Field(description=\"should be one of: music, customer\")\n",
    "\n",
    "system_message = \"\"\"Your job is to help as a customer service representative for a music store.\n",
    "\n",
    "You should interact politely with customers to try to figure out how you can help. You can help in a few ways:\n",
    "\n",
    "- Updating user information: if a customer wants to update the information in the user database. Call the router with `customer`\n",
    "- Recomending music: if a customer wants to find some music or information about music. Call the router with `music`\n",
    "\n",
    "If the user is asking or wants to ask about updating or accessing their information, send them to that route.\n",
    "If the user is asking or wants to ask about music, send them to that route.\n",
    "Otherwise, respond.\"\"\"\n",
    "def get_messages(messages):\n",
    "    return [SystemMessage(content=system_message)] + messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ddf27314",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = get_messages | model.bind_tools([Router])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3c896f34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_RpE3v45nw65Cx0RcgUXVbU8M', 'function': {'arguments': '{\"choice\":\"music\"}', 'name': 'Router'}, 'type': 'function'}]})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "msgs = [HumanMessage(content=\"hi! can you help me find a good song?\")]\n",
    "chain.invoke(msgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "40d86f59",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_GoXNLA8goAQJraVW3dMF1Uze', 'function': {'arguments': '{\"choice\":\"customer\"}', 'name': 'Router'}, 'type': 'function'}]})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "msgs = [HumanMessage(content=\"hi! whats the email you have for me?\")]\n",
    "chain.invoke(msgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "bd6ddd8b-7500-46a7-811d-3bcb937bda51",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import AIMessage\n",
    "\n",
    "def add_name(message, name):\n",
    "    _dict = message.dict()\n",
    "    _dict[\"name\"] = name\n",
    "    return AIMessage(**_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "27494de5-8345-4c23-bc0e-81e0dd5d47d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import END\n",
    "import json\n",
    "\n",
    "def _get_last_ai_message(messages):\n",
    "    for m in messages[::-1]:\n",
    "        if isinstance(m, AIMessage):\n",
    "            return m\n",
    "    return None\n",
    "\n",
    "\n",
    "def _is_tool_call(msg):\n",
    "    return hasattr(msg, \"additional_kwargs\") and 'tool_calls' in msg.additional_kwargs\n",
    "\n",
    "\n",
    "def _route(messages):\n",
    "    last_message = messages[-1]\n",
    "    if isinstance(last_message, AIMessage):\n",
    "        if not _is_tool_call(last_message):\n",
    "            return END\n",
    "        else:\n",
    "            if last_message.name == \"general\":\n",
    "                tool_calls = last_message.additional_kwargs['tool_calls']\n",
    "                if len(tool_calls) > 1:\n",
    "                    raise ValueError\n",
    "                tool_call = tool_calls[0]\n",
    "                return json.loads(tool_call['function']['arguments'])['choice']\n",
    "            else:\n",
    "                return \"tools\"\n",
    "    last_m = _get_last_ai_message(messages)\n",
    "    if last_m is None:\n",
    "        return \"general\"\n",
    "    if last_m.name == \"music\":\n",
    "        return \"music\"\n",
    "    elif last_m.name == \"customer\":\n",
    "        return \"customer\"\n",
    "    else:\n",
    "        return \"general\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "8aec704a-46fe-4fb3-bdee-11c3bbffc370",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "from langgraph.prebuilt import ToolExecutor, ToolInvocation\n",
    "\n",
    "tools = [get_albums_by_artist, get_tracks_by_artist, check_for_songs, get_customer_info]\n",
    "tool_executor = ToolExecutor(tools)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4d5b75c6-73e0-4922-a765-a15be63f869e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _filter_out_routes(messages):\n",
    "    ms = []\n",
    "    for m in messages:\n",
    "        if _is_tool_call(m):\n",
    "            if m.name == \"general\":\n",
    "                continue\n",
    "        ms.append(m)\n",
    "    return ms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fd4dbf98-dbb3-411a-bad6-2bb334072aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "general_node = _filter_out_routes | chain | partial(add_name, name=\"general\")\n",
    "music_node = _filter_out_routes | song_recc_chain | partial(add_name, name=\"music\")\n",
    "customer_node = _filter_out_routes | customer_chain | partial(add_name, name=\"customer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9a1d7243",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import ToolMessage\n",
    "\n",
    "\n",
    "async def call_tool(messages):\n",
    "    actions = []\n",
    "    # Based on the continue condition\n",
    "    # we know the last message involves a function call\n",
    "    last_message = messages[-1]\n",
    "    for tool_call in last_message.additional_kwargs[\"tool_calls\"]:\n",
    "        function = tool_call[\"function\"]\n",
    "        function_name = function[\"name\"]\n",
    "        _tool_input = json.loads(function[\"arguments\"] or \"{}\")\n",
    "        # We construct an ToolInvocation from the function_call\n",
    "        actions.append(\n",
    "            ToolInvocation(\n",
    "                tool=function_name,\n",
    "                tool_input=_tool_input,\n",
    "            )\n",
    "        )\n",
    "    # We call the tool_executor and get back a response\n",
    "    responses = await tool_executor.abatch(actions)\n",
    "    # We use the response to create a ToolMessage\n",
    "    tool_messages = [\n",
    "        ToolMessage(\n",
    "            tool_call_id=tool_call[\"id\"],\n",
    "            content=str(response),\n",
    "            additional_kwargs={\"name\": tool_call[\"function\"][\"name\"]},\n",
    "        )\n",
    "        for tool_call, response in zip(\n",
    "            last_message.additional_kwargs[\"tool_calls\"], responses\n",
    "        )\n",
    "    ]\n",
    "    return tool_messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "dcade924",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import MessageGraph\n",
    "from langgraph.checkpoint.sqlite import SqliteSaver\n",
    "\n",
    "memory = SqliteSaver.from_conn_string(\":memory:\")\n",
    "graph = MessageGraph()\n",
    "nodes = {\"general\": \"general\", \"music\": \"music\", END: END, \"tools\": \"tools\", \"customer\": \"customer\"}\n",
    "# Define a new graph\n",
    "workflow = MessageGraph()\n",
    "workflow.add_node(\"general\", general_node)\n",
    "workflow.add_node(\"music\", music_node)\n",
    "workflow.add_node(\"customer\", customer_node)\n",
    "workflow.add_node(\"tools\", call_tool)\n",
    "workflow.add_conditional_edges(\"general\", _route, nodes)\n",
    "workflow.add_conditional_edges(\"tools\", _route, nodes)\n",
    "workflow.add_conditional_edges(\"music\", _route, nodes)\n",
    "workflow.add_conditional_edges(\"customer\", _route, nodes)\n",
    "workflow.set_conditional_entry_point(_route, nodes)\n",
    "graph = workflow.compile()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ac65d6d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  hi!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'general':\n",
      "---\n",
      "content='Hello! How can I assist you today?' name='general'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  what songs do you have?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'general':\n",
      "---\n",
      "content='' additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_1IfQfrOHQXRuqz20GBr0GB7p', 'function': {'arguments': '{\"choice\":\"music\"}', 'name': 'Router'}, 'type': 'function'}]} name='general'\n",
      "\n",
      "---\n",
      "\n",
      "Output from node 'music':\n",
      "---\n",
      "content=\"I can help you find songs by specific artists or songs with particular titles. If you have a favorite artist or a song in mind, let me know, and I'll do my best to find information for you.\" name='music'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  anything by t swift?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'music':\n",
      "---\n",
      "content='' additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_HLdMla4vn6g3SZulKZRtNJdp', 'function': {'arguments': '{\"artist\": \"Taylor Swift\"}', 'name': 'get_albums_by_artist'}, 'type': 'function'}, {'index': 1, 'id': 'call_ZtC6LQpaieVentCQaujm5Oam', 'function': {'arguments': '{\"artist\": \"Taylor Swift\"}', 'name': 'get_tracks_by_artist'}, 'type': 'function'}]} name='music'\n",
      "\n",
      "---\n",
      "\n",
      "Output from node 'tools':\n",
      "---\n",
      "[ToolMessage(content=\"[{'Title': 'International Superhits', 'Name': 'Green Day'}, {'Title': 'American Idiot', 'Name': 'Green Day'}]\", additional_kwargs={'name': 'get_albums_by_artist'}, tool_call_id='call_HLdMla4vn6g3SZulKZRtNJdp'), ToolMessage(content='[{\\'SongName\\': \\'Maria\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Poprocks And Coke\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Longview\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Welcome To Paradise\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Basket Case\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'When I Come Around\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'She\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'J.A.R. (Jason Andrew Relva)\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Geek Stink Breath\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Brain Stew\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Jaded\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Walking Contradiction\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Stuck With Me\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \"Hitchin\\' A Ride\", \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Good Riddance (Time Of Your Life)\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Redundant\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Nice Guys Finish Last\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Minority\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Warning\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Waiting\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \"Macy\\'s Day Parade\", \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'American Idiot\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \"Jesus Of Suburbia / City Of The Damned / I Don\\'t Care / Dearly Beloved / Tales Of Another Broken Home\", \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Holiday\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Boulevard Of Broken Dreams\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Are We The Waiting\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'St. Jimmy\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Give Me Novacaine\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \"She\\'s A Rebel\", \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Extraordinary Girl\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Letterbomb\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Wake Me Up When September Ends\\', \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \"Homecoming / The Death Of St. Jimmy / East 12th St. / Nobody Likes You / Rock And Roll Girlfriend / We\\'re Coming Home Again\", \\'ArtistName\\': \\'Green Day\\'}, {\\'SongName\\': \\'Whatsername\\', \\'ArtistName\\': \\'Green Day\\'}]', additional_kwargs={'name': 'get_tracks_by_artist'}, tool_call_id='call_ZtC6LQpaieVentCQaujm5Oam')]\n",
      "\n",
      "---\n",
      "\n",
      "Output from node 'music':\n",
      "---\n",
      "content='It seems there was a mix-up in the search, and I received information related to Green Day instead of Taylor Swift. Unfortunately, I can\\'t directly access or correct this error in real-time. However, Taylor Swift has a vast discography with many popular albums and songs. Some of her well-known albums include \"Fearless,\" \"1989,\" \"Reputation,\" \"Lover,\" \"Folklore,\" and \"Evermore.\" Her music spans across various genres, including country, pop, and indie folk.\\n\\nIf you\\'re looking for specific songs or albums by Taylor Swift, please let me know, and I\\'ll do my best to provide you with the information you\\'re seeking!' name='music'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  q\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AI: Byebye\n"
     ]
    }
   ],
   "source": [
    "import uuid\n",
    "from langchain_core.messages import HumanMessage\n",
    "from langgraph.graph.graph import START\n",
    "\n",
    "history = []\n",
    "while True:\n",
    "    user = input('User (q/Q to quit): ')\n",
    "    if user in {'q', 'Q'}:\n",
    "        print('AI: Byebye')\n",
    "        break\n",
    "    history.append(HumanMessage(content=user))\n",
    "    async for output in graph.astream(history):\n",
    "        if END in output or START in output:\n",
    "            continue\n",
    "        # stream() yields dictionaries with output keyed by node name\n",
    "        for key, value in output.items():\n",
    "            print(f\"Output from node '{key}':\")\n",
    "            print(\"---\")\n",
    "            print(value)\n",
    "        print(\"\\n---\\n\")\n",
    "    history = output[END]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "ccc495cd-9c2b-4b90-ba8b-88dd8f33966f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  hi! whats the email you have on file?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'general':\n",
      "---\n",
      "content='' additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_fA4APR8G3SIxHLN8Sfv2F6si', 'function': {'arguments': '{\"choice\":\"customer\"}', 'name': 'Router'}, 'type': 'function'}]} name='general'\n",
      "\n",
      "---\n",
      "\n",
      "Output from node 'customer':\n",
      "---\n",
      "content=\"To help you with that, I'll need your customer ID. Could you provide it, please?\" name='customer'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output from node 'customer':\n",
      "---\n",
      "content='' additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_pPUA6QaH2kQG9MRLjRn0PCDY', 'function': {'arguments': '{\"customer_id\":1}', 'name': 'get_customer_info'}, 'type': 'function'}]} name='customer'\n",
      "\n",
      "---\n",
      "\n",
      "Output from node 'tools':\n",
      "---\n",
      "[ToolMessage(content=\"[(1, 'Luís', 'Gonçalves', 'Embraer - Empresa Brasileira de Aeronáutica S.A.', 'Av. Brigadeiro Faria Lima, 2170', 'São José dos Campos', 'SP', 'Brazil', '12227-000', '+55 (12) 3923-5555', '+55 (12) 3923-5566', 'luisg@embraer.com.br', 3)]\", additional_kwargs={'name': 'get_customer_info'}, tool_call_id='call_pPUA6QaH2kQG9MRLjRn0PCDY')]\n",
      "\n",
      "---\n",
      "\n",
      "Output from node 'customer':\n",
      "---\n",
      "content='The email we have on file for you is luisg@embraer.com.br. Is there anything else I can assist you with?' name='customer'\n",
      "\n",
      "---\n",
      "\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "User (q/Q to quit):  q\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AI: Byebye\n"
     ]
    }
   ],
   "source": [
    "history = []\n",
    "while True:\n",
    "    user = input('User (q/Q to quit): ')\n",
    "    if user in {'q', 'Q'}:\n",
    "        print('AI: Byebye')\n",
    "        break\n",
    "    history.append(HumanMessage(content=user))\n",
    "    async for output in graph.astream(history):\n",
    "        if END in output or START in output:\n",
    "            continue\n",
    "        # stream() yields dictionaries with output keyed by node name\n",
    "        for key, value in output.items():\n",
    "            print(f\"Output from node '{key}':\")\n",
    "            print(\"---\")\n",
    "            print(value)\n",
    "        print(\"\\n---\\n\")\n",
    "    history = output[END]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
